/*
 * Copyright (c) 2023 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "sec_storage_ipc_stub_inner.h"

#include <securec.h>
#include <stddef.h>

#include "logger.h"
#include "module_common_ipc_stub.h"
#include "parcel.h"
#include "se_module_sec_storage_defines.h"
#include "sec_storage_core.h"
#include "sec_storage_ipc_defines.h"
#include "sec_storage_ipc_stub.h"
#include "sec_storage_ipc_stub_serialization.h"

/* Size, in bytes, of the stack buffers used in this file for key data (and for the factory-reset nonce). */
#define KEY_DATA_LEN 256
/* Size, in bytes, of the stack buffers used in this file for slot payload data and the service binding key. */
#define BUFFER_DATA_LEN 512

/*
 * Stub handler: set the factory-reset authentication key (only built with ENABLE_FACTORY).
 *
 * Decodes level, algorithm and key from `buffer` with SetFactoryResetAuthKeyInputFromBuffer(), marks the
 * buffer as carrying no reply payload and forwards the request to StorageSetFactoryResetAuthenticationKey().
 * The local key copy is scrubbed before returning on every path after decoding has been attempted.
 *
 * Returns INVALID_PARA_NULL_PTR when `context` is NULL, the decoder's error code on decode failure,
 * otherwise the result of the storage call. Without ENABLE_FACTORY it always returns NOT_SUPPORT.
 */
ResultCode ProcStorageCmdSetFactoryResetAuthKeyStub(SecStorageContext *context, SharedDataBuffer *buffer)
{
#ifdef ENABLE_FACTORY
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }
    FactoryResetLevel level = LEVEL_USER_WIPE;
    FactoryResetAuthAlgo algo = ALGO_NIST_P256;

    uint8_t keyData[KEY_DATA_LEN] = {0};
    StorageAuthKey key = {.keyData = keyData, .keySize = KEY_DATA_LEN};

    ResultCode ret = SetFactoryResetAuthKeyInputFromBuffer(buffer, &level, &algo, &key);
    if (ret != SUCCESS) {
        LOG_ERROR("SetFactoryResetAuthKeyInputFromBuffer err %u", ret);
    } else {
        // no need to reply
        buffer->dataSize = 0;
        ret = StorageSetFactoryResetAuthenticationKey(context->base.channel, level, algo, &key);
    }

    // Do not leave key material on the stack, matching the slot stubs in this file.
    (void)memset_s(keyData, KEY_DATA_LEN, 0, KEY_DATA_LEN);
    return ret;
#else
    (void)context;
    (void)buffer;
    return NOT_SUPPORT;
#endif
}

/*
 * Stub handler: query the factory-reset authentication algorithm for a reset level.
 *
 * Decodes the level from `buffer` with GetFactoryResetAuthKeyAlgoInputFromBuffer(), asks
 * StorageGetFactoryResetAuthenticationAlgo() for the algorithm and hands the result to
 * GetFactoryResetAuthKeyAlgoOutputToBuffer().
 *
 * Returns INVALID_PARA_NULL_PTR when `context` is NULL, the first failing step's error code,
 * otherwise the result of the output serializer.
 */
ResultCode ProcStorageCmdGetFactoryResetAuthKeyAlgoStub(SecStorageContext *context, SharedDataBuffer *buffer)
{
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    FactoryResetLevel level = LEVEL_USER_WIPE;
    ResultCode ret = GetFactoryResetAuthKeyAlgoInputFromBuffer(buffer, &level);
    if (ret != SUCCESS) {
        // The log names the function that actually failed (it previously named the Set... decoder).
        LOG_ERROR("GetFactoryResetAuthKeyAlgoInputFromBuffer err %u", ret);
        return ret;
    }
    FactoryResetAuthAlgo algo = ALGO_NIST_P256;
    ret = StorageGetFactoryResetAuthenticationAlgo(context->base.channel, level, &algo);
    if (ret != SUCCESS) {
        LOG_ERROR("StorageGetFactoryResetAuthenticationAlgo err %u", ret);
        return ret;
    }

    return GetFactoryResetAuthKeyAlgoOutputToBuffer(algo, buffer);
}

/*
 * Stub handler: prepare a factory reset.
 *
 * Calls StoragePrepareFactoryReset() with a local nonce buffer of KEY_DATA_LEN bytes and, on success,
 * passes the nonce and its reported size to PrepareFactoryResetOutputToBuffer() together with `buffer`.
 *
 * Returns INVALID_PARA_NULL_PTR when `context` is NULL, the storage error code on failure,
 * otherwise the result of the output serializer.
 */
ResultCode ProcStorageCmdPrepareFactoryResetStub(SecStorageContext *context, SharedDataBuffer *buffer)
{
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    uint8_t nonce[KEY_DATA_LEN] = {0};
    uint32_t nonceLen = KEY_DATA_LEN;

    ResultCode result = StoragePrepareFactoryReset(context->base.channel, nonce, &nonceLen);
    if (result == SUCCESS) {
        return PrepareFactoryResetOutputToBuffer(nonce, nonceLen, buffer);
    }

    LOG_ERROR("StoragePrepareFactoryReset err %u", result);
    return result;
}

/*
 * Stub handler: execute a factory reset.
 *
 * Decodes the reset level and key data from `buffer` with ProcessFactoryResetInputFromBuffer() and forwards
 * them to StorageProcessFactoryReset(). The local key buffer is scrubbed before returning on every path
 * after decoding has been attempted, consistent with the slot stubs in this file.
 *
 * Returns INVALID_PARA_NULL_PTR when `context` is NULL, the decoder's error code on decode failure,
 * otherwise the result of the storage call.
 */
ResultCode ProcStorageCmdProcessFactoryResetStub(SecStorageContext *context, SharedDataBuffer *buffer)
{
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    FactoryResetLevel level = LEVEL_USER_WIPE;
    uint8_t keyData[KEY_DATA_LEN] = {0};
    uint32_t keySize = KEY_DATA_LEN;
    ResultCode ret = ProcessFactoryResetInputFromBuffer(buffer, &level, keyData, &keySize);
    if (ret != SUCCESS) {
        LOG_ERROR("ProcessFactoryResetInputFromBuffer err %u", ret);
    } else {
        ret = StorageProcessFactoryReset(context->base.channel, level, keyData, keySize);
    }

    // Do not leave the authorization data on the stack.
    (void)memset_s(keyData, KEY_DATA_LEN, 0, KEY_DATA_LEN);
    return ret;
}

/*
 * Stub handler: switch the storage to user mode (only built with ENABLE_FACTORY).
 *
 * Decodes a StorageUserModeConf from `buffer` with SetToUserModeInputFromBuffer() and passes it to
 * StorageSetToUserMode().
 *
 * Returns INVALID_PARA_NULL_PTR when `context` is NULL, the decoder's error code on decode failure,
 * otherwise the result of the storage call. Without ENABLE_FACTORY it always returns NOT_SUPPORT.
 */
ResultCode ProcStorageCmdSetToUserModeStub(SecStorageContext *context, SharedDataBuffer *buffer)
{
#ifdef ENABLE_FACTORY
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    StorageUserModeConf userModeConf = {0};
    ResultCode result = SetToUserModeInputFromBuffer(buffer, &userModeConf);
    if (result == SUCCESS) {
        return StorageSetToUserMode(context->base.channel, &userModeConf);
    }

    LOG_ERROR("SetToUserModeInputFromBuffer err %u", result);
    return result;
#else
    (void)context;
    (void)buffer;
    return NOT_SUPPORT;
#endif
}

/*
 * Stub handler: report whether a slot-operation algorithm is supported.
 *
 * Decodes the algorithm from `buffer` with GetSlotOperateAlgorithmInputFromBuffer(), queries
 * StorageIsSlotOperateAlgorithmSupported() and passes the resulting flag to
 * GetSlotOperateAlgorithmOutputToBuffer().
 *
 * Returns INVALID_PARA_NULL_PTR when `context` is NULL, the first failing step's error code,
 * otherwise the result of the output serializer.
 */
ResultCode ProcStorageCmdGetSlotOperateAlgorithmStub(SecStorageContext *context, SharedDataBuffer *buffer)
{
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    SlotOperAlgo requested = ALGO_HKDF_SHA256;
    uint32_t supported = 0;

    ResultCode result = GetSlotOperateAlgorithmInputFromBuffer(buffer, &requested);
    if (result != SUCCESS) {
        LOG_ERROR("GetSlotOperateAlgorithmInputFromBuffer err %u", result);
        return result;
    }

    result = StorageIsSlotOperateAlgorithmSupported(context->base.channel, requested, &supported);
    if (result == SUCCESS) {
        return GetSlotOperateAlgorithmOutputToBuffer(supported, buffer);
    }

    LOG_ERROR("StorageIsSlotOperateAlgorithmSupported err %u", result);
    return result;
}

/*
 * Stub handler: report whether a factory-reset authentication algorithm is supported.
 *
 * Decodes the algorithm from `buffer` with GetFactoryResetAlgorithmInputFromBuffer(), queries
 * StorageIsFactoryResetAlgorithmSupported() and passes the resulting flag to
 * GetFactoryResetAlgorithmOutputToBuffer().
 *
 * Returns INVALID_PARA_NULL_PTR when `context` is NULL, the first failing step's error code,
 * otherwise the result of the output serializer.
 */
ResultCode ProcStorageCmdGetFactoryResetAlgorithmStub(SecStorageContext *context, SharedDataBuffer *buffer)
{
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    FactoryResetAuthAlgo algo = ALGO_NIST_P256;
    ResultCode ret = GetFactoryResetAlgorithmInputFromBuffer(buffer, &algo);
    if (ret != SUCCESS) {
        // The log names the function that actually failed (it previously named the slot-operate decoder).
        LOG_ERROR("GetFactoryResetAlgorithmInputFromBuffer err %u", ret);
        return ret;
    }

    uint32_t available = 0;
    ret = StorageIsFactoryResetAlgorithmSupported(context->base.channel, algo, &available);
    if (ret != SUCCESS) {
        LOG_ERROR("StorageIsFactoryResetAlgorithmSupported err %u", ret);
        return ret;
    }

    return GetFactoryResetAlgorithmOutputToBuffer(available, buffer);
}

/*
 * Stub handler: configure the size of all slots (only built with ENABLE_FACTORY).
 *
 * Decodes up to MAX_SLOTS_NUM slot sizes from `buffer` with SetSetAllSlotsSizeInputFromBuffer() and passes
 * the array and its decoded length to StorageSetAllSlotsSize().
 *
 * Returns INVALID_PARA_NULL_PTR when `context` is NULL, the decoder's error code on decode failure,
 * otherwise the result of the storage call. Without ENABLE_FACTORY it always returns NOT_SUPPORT.
 */
ResultCode ProcStorageCmdSetAllSlotsSizeStub(SecStorageContext *context, SharedDataBuffer *buffer)
{
#ifdef ENABLE_FACTORY
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    uint16_t sizes[MAX_SLOTS_NUM] = {0};
    uint32_t sizeCount = MAX_SLOTS_NUM;

    ResultCode result = SetSetAllSlotsSizeInputFromBuffer(buffer, sizes, &sizeCount);
    if (result == SUCCESS) {
        return StorageSetAllSlotsSize(context->base.channel, sizes, sizeCount);
    }

    LOG_ERROR("SetSetAllSlotsSizeInputFromBuffer err %u", result);
    return result;
#else
    (void)context;
    (void)buffer;
    return NOT_SUPPORT;
#endif
}

/*
 * Stub handler: fetch the size of all slots.
 *
 * Calls StorageGetAllSlotsSize() with a local array of MAX_SLOTS_NUM entries and, on success, passes the
 * array and the reported length to GetAllSlotsSizeOutputToBuffer() together with `buffer`.
 *
 * Returns INVALID_PARA_NULL_PTR when `context` is NULL, the storage error code on failure,
 * otherwise the result of the output serializer.
 */
ResultCode ProcStorageCmdGetAllSlotsSizeStub(SecStorageContext *context, SharedDataBuffer *buffer)
{
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    uint16_t sizes[MAX_SLOTS_NUM] = {0};
    uint32_t sizeCount = MAX_SLOTS_NUM;

    ResultCode result = StorageGetAllSlotsSize(context->base.channel, sizes, &sizeCount);
    if (result == SUCCESS) {
        return GetAllSlotsSizeOutputToBuffer(sizes, sizeCount, buffer);
    }

    LOG_ERROR("StorageGetAllSlotsSize err %u", result);
    return result;
}

/*
 * Stub handler: allocate a storage slot.
 *
 * Steps, in order: decode file name, slot attributes and key from `buffer`; map the file name to a slot id
 * via ConvertStorageFileNameToId(); run the key through ProcessStorageKeyDerive(); call StorageAllocateSlot().
 * The local key buffer is scrubbed on every exit path once decoding has been attempted, including the
 * early error returns.
 *
 * Returns INVALID_PARA_NULL_PTR when `context` is NULL, otherwise the error code of the first failing step
 * or the result of StorageAllocateSlot().
 */
ResultCode ProcStorageCmdAllocateSlotStub(SecStorageContext *context, SharedDataBuffer *buffer)
{
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    StorageFileName name = {0};
    StorageSlotAttr slotAttr;
    (void)memset_s(&slotAttr, sizeof(StorageSlotAttr), 0, sizeof(StorageSlotAttr));

    uint8_t keyData[KEY_DATA_LEN] = {0};
    StorageAuthKey key = {.keyData = keyData, .keySize = KEY_DATA_LEN};
    // Declared before the first goto so that no jump skips an initializer.
    uint8_t slotId = 0;

    ResultCode ret = AllocateSlotInputFromBuffer(buffer, &name, &slotAttr, &key);
    if (ret != SUCCESS) {
        LOG_ERROR("AllocateSlotInputFromBuffer err %u", ret);
        goto EXIT;
    }

    ret = ConvertStorageFileNameToId(context, &name, &slotId);
    if (ret != SUCCESS) {
        LOG_ERROR("ConvertStorageFileNameToId err %u", ret);
        goto EXIT;
    }

    ret = ProcessStorageKeyDerive(context, key.keyData, key.keySize);
    if (ret != SUCCESS) {
        LOG_ERROR("ProcessStorageKeyDerive err %u", ret);
        goto EXIT;
    }
    LOG_INFO("slotId is %u", slotId);
    ret = StorageAllocateSlot(context->base.channel, slotId, &slotAttr, &key);

EXIT:
    // Single cleanup point: never leave key material on the stack, whatever the outcome.
    (void)memset_s(keyData, KEY_DATA_LEN, 0, KEY_DATA_LEN);
    return ret;
}

/*
 * Stub handler: write data into a storage slot.
 *
 * Steps, in order: decode file name, key, data area and payload from `buffer`; map the file name to a slot id
 * via ConvertStorageFileNameToId(); run the key through ProcessStorageKeyDerive(); call StorageWriteSlot().
 * The local key and payload buffers are scrubbed on every exit path once decoding has been attempted,
 * including the early error returns.
 *
 * Returns INVALID_PARA_NULL_PTR when `context` is NULL, otherwise the error code of the first failing step
 * or the result of StorageWriteSlot().
 */
ResultCode ProcStorageCmdWriteSlotStub(SecStorageContext *context, SharedDataBuffer *buffer)
{
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    StorageFileName name = {0};
    uint8_t keyData[KEY_DATA_LEN] = {0};
    StorageAuthKey key = {.keyData = keyData, .keySize = KEY_DATA_LEN};
    StorageDataArea area = {0, 0};

    uint8_t bufferData[BUFFER_DATA_LEN] = {0};
    StorageDataBuffer data = {.bufferData = bufferData, .bufferSize = BUFFER_DATA_LEN};
    // Declared before the first goto so that no jump skips an initializer.
    uint8_t slotId = 0;

    ResultCode ret = WriteSlotInputFromBuffer(buffer, &name, &key, &area, &data);
    if (ret != SUCCESS) {
        LOG_ERROR("WriteSlotInputFromBuffer err %u", ret);
        goto EXIT;
    }

    ret = ConvertStorageFileNameToId(context, &name, &slotId);
    if (ret != SUCCESS) {
        LOG_ERROR("ConvertStorageFileNameToId err %u", ret);
        goto EXIT;
    }

    ret = ProcessStorageKeyDerive(context, key.keyData, key.keySize);
    if (ret != SUCCESS) {
        LOG_ERROR("ProcessStorageKeyDerive err %u", ret);
        goto EXIT;
    }

    LOG_INFO("slotId is %u", slotId);
    ret = StorageWriteSlot(context->base.channel, slotId, &key, &area, &data);

EXIT:
    // Single cleanup point: scrub the key and the payload copy, whatever the outcome.
    (void)memset_s(keyData, KEY_DATA_LEN, 0, KEY_DATA_LEN);
    (void)memset_s(bufferData, BUFFER_DATA_LEN, 0, BUFFER_DATA_LEN);
    return ret;
}

/*
 * Stub handler: read data from a storage slot.
 *
 * Steps, in order: decode file name, key and data area from `buffer`; run the key through
 * ProcessStorageKeyDerive(); map the file name to a slot id via ConvertStorageFileNameToId(); call
 * StorageReadSlot(); pass the data to ReadSlotOutputToBuffer() together with `buffer`.
 * The local key and data buffers are scrubbed on every exit path once decoding has been attempted,
 * including the early error returns and a StorageReadSlot() failure.
 *
 * Returns INVALID_PARA_NULL_PTR when `context` is NULL, otherwise the error code of the first failing step
 * or the result of ReadSlotOutputToBuffer().
 */
ResultCode ProcStorageCmdReadSlotStub(SecStorageContext *context, SharedDataBuffer *buffer)
{
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    StorageFileName name = {0};
    uint8_t keyData[KEY_DATA_LEN] = {0};
    StorageAuthKey key = {.keyData = keyData, .keySize = KEY_DATA_LEN};
    StorageDataArea area = {0, 0};

    // All locals are declared before the first goto so that no jump skips an initializer.
    uint8_t bufferData[BUFFER_DATA_LEN] = {0};
    StorageDataBuffer data = {.bufferData = bufferData, .bufferSize = BUFFER_DATA_LEN};
    uint8_t slotId = 0;

    ResultCode ret = ReadSlotInputFromBuffer(buffer, &name, &key, &area);
    if (ret != SUCCESS) {
        LOG_ERROR("ReadSlotInputFromBuffer err %u", ret);
        goto EXIT;
    }

    ret = ProcessStorageKeyDerive(context, key.keyData, key.keySize);
    if (ret != SUCCESS) {
        LOG_ERROR("ProcessStorageKeyDerive err %u", ret);
        goto EXIT;
    }

    ret = ConvertStorageFileNameToId(context, &name, &slotId);
    if (ret != SUCCESS) {
        LOG_ERROR("ConvertStorageFileNameToId err %u", ret);
        goto EXIT;
    }

    ret = StorageReadSlot(context->base.channel, slotId, &key, &area, &data);
    if (ret != SUCCESS) {
        LOG_ERROR("StorageReadSlot err %u", ret);
        goto EXIT;
    }
    LOG_INFO("slotId is %u", slotId);
    ret = ReadSlotOutputToBuffer(&data, buffer);

EXIT:
    // Single cleanup point: scrub the key and the local copy of the slot contents, whatever the outcome.
    (void)memset_s(keyData, KEY_DATA_LEN, 0, KEY_DATA_LEN);
    (void)memset_s(bufferData, BUFFER_DATA_LEN, 0, BUFFER_DATA_LEN);
    return ret;
}

/*
 * Stub handler: free a storage slot.
 *
 * Steps, in order: decode file name and key from `buffer`; run the key through ProcessStorageKeyDerive();
 * map the file name to a slot id via ConvertStorageFileNameToId(); call StorageFreeSlot().
 * The local key buffer is scrubbed on every exit path once decoding has been attempted, including the
 * early error returns.
 *
 * Returns INVALID_PARA_NULL_PTR when `context` is NULL, otherwise the error code of the first failing step
 * or the result of StorageFreeSlot().
 */
ResultCode ProcStorageCmdFreeSlotStub(SecStorageContext *context, SharedDataBuffer *buffer)
{
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }
    StorageFileName name = {0};
    uint8_t keyData[KEY_DATA_LEN] = {0};
    StorageAuthKey key = {.keyData = keyData, .keySize = KEY_DATA_LEN};
    // Declared before the first goto so that no jump skips an initializer.
    uint8_t slotId = 0;

    ResultCode ret = FreeSlotInputFromBuffer(buffer, &name, &key);
    if (ret != SUCCESS) {
        LOG_ERROR("FreeSlotInputFromBuffer err %u", ret);
        goto EXIT;
    }

    ret = ProcessStorageKeyDerive(context, key.keyData, key.keySize);
    if (ret != SUCCESS) {
        LOG_ERROR("ProcessStorageKeyDerive err %u", ret);
        goto EXIT;
    }

    ret = ConvertStorageFileNameToId(context, &name, &slotId);
    if (ret != SUCCESS) {
        LOG_ERROR("ConvertStorageFileNameToId err %u", ret);
        goto EXIT;
    }
    LOG_INFO("slotId is %u", slotId);
    ret = StorageFreeSlot(context->base.channel, slotId, &key);

EXIT:
    // Single cleanup point: never leave key material on the stack, whatever the outcome.
    (void)memset_s(keyData, KEY_DATA_LEN, 0, KEY_DATA_LEN);
    return ret;
}

/*
 * Stub handler: query the status of a storage slot.
 *
 * Decodes the file name from `buffer` with GetSlotStatusInputFromBuffer(), maps it to a slot id via
 * ConvertStorageFileNameToId(), fetches the status with StorageGetSlotStatus(), logs the counters and
 * passes the status to GetSlotStatusOutputToBuffer() together with `buffer`.
 *
 * Returns INVALID_PARA_NULL_PTR when `context` is NULL, the first failing step's error code,
 * otherwise the result of the output serializer.
 */
ResultCode ProcStorageCmdGetSlotStatusStub(SecStorageContext *context, SharedDataBuffer *buffer)
{
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    StorageFileName fileName = {0};
    ResultCode result = GetSlotStatusInputFromBuffer(buffer, &fileName);
    if (result != SUCCESS) {
        LOG_ERROR("GetSlotStatusInputFromBuffer err %u", result);
        return result;
    }

    uint8_t id = 0;
    result = ConvertStorageFileNameToId(context, &fileName, &id);
    if (result != SUCCESS) {
        LOG_ERROR("ConvertStorageFileNameToId err %u", result);
        return result;
    }

    StorageSlotStatus slotStatus;
    (void)memset_s(&slotStatus, sizeof(StorageSlotStatus), 0, sizeof(StorageSlotStatus));
    result = StorageGetSlotStatus(context->base.channel, id, &slotStatus);
    if (result != SUCCESS) {
        LOG_ERROR("StorageGetSlotStatus err %u", result);
        return result;
    }

    LOG_INFO("slotId is %u, re is %u, we is %u, fe is %u, alloc is %u, algo is %u", id, slotStatus.readErrCnt,
        slotStatus.writeErrCnt, slotStatus.freeErrCnt, (uint32_t)slotStatus.status,
        (uint32_t)slotStatus.slotAttr.algo);
    return GetSlotStatusOutputToBuffer(&slotStatus, buffer);
}

// private impl
/*
 * Runs `key` (of `len` bytes) through the context's `derive` callback, passing the context's sender.
 *
 * Returns INVALID_PARA_NULL_PTR when `context` or its `derive` callback is NULL,
 * otherwise whatever the callback returns.
 */
ResultCode ProcessStorageKeyDerive(SecStorageContext *context, uint8_t *key, uint32_t len)
{
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    if (context->derive != NULL) {
        return context->derive(context->base.sender, key, len);
    }

    LOG_ERROR("derive is nullptr");
    return INVALID_PARA_NULL_PTR;
}

/*
 * Maps a storage file name to a slot id using the context's `slotIdGetter` callback.
 *
 * The name handle is first copied into a local buffer that is one byte larger than the copied length and
 * zero-initialized, so the string handed to the callback is always NUL-terminated.
 *
 * Returns INVALID_PARA_NULL_PTR when `context`, `name`, `slotId` or the callback is NULL, MEM_COPY_ERR when
 * the copy fails, otherwise whatever the callback returns.
 */
ResultCode ConvertStorageFileNameToId(const SecStorageContext *context, StorageFileName *name, uint8_t *slotId)
{
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    // Validate the output pointer here as well instead of relying on every registered getter to do it.
    if (name == NULL || slotId == NULL) {
        LOG_ERROR("name or slotId is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    if (context->slotIdGetter == NULL) {
        LOG_ERROR("slotIdGetter is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    char input[SEC_STORAGE_FILE_NAME_SIZE_MAX + 1] = {0};
    // NOTE(review): assumes name->handle holds at least SEC_STORAGE_FILE_NAME_SIZE_MAX bytes - confirm.
    if (memcpy_s(input, SEC_STORAGE_FILE_NAME_SIZE_MAX, name->handle, SEC_STORAGE_FILE_NAME_SIZE_MAX) != EOK) {
        LOG_ERROR("memcpy_s file name failed");
        return MEM_COPY_ERR;
    }

    return context->slotIdGetter(context->base.sender, input, slotId);
}

/*
 * Enables the secure channel protocol on the context's channel.
 *
 * Obtains the binding key and its version from the context's `bindingKeyGetter` callback, passes them to
 * ChannelEnableSecureChannelProtocol() and then installs DefaultScpEnableChecker via
 * ChannelSetSecureChannelProtocolChecker(). The local copy of the binding key is scrubbed before returning
 * on every path after the getter has been called.
 *
 * Returns INVALID_PARA_NULL_PTR when `context` or the callback is NULL, otherwise the error code of the
 * first failing step or the result of ChannelSetSecureChannelProtocolChecker().
 */
ResultCode EnableStorageScpProtocol(const SecStorageContext *context)
{
    if (context == NULL) {
        LOG_ERROR("context is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    if (context->bindingKeyGetter == NULL) {
        LOG_ERROR("bindingKeyGetter is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    uint8_t key[BUFFER_DATA_LEN] = {0};
    uint32_t keyLen = BUFFER_DATA_LEN;
    uint32_t keyVersion = 0;

    ResultCode ret = context->bindingKeyGetter(key, &keyLen, &keyVersion);
    if (ret != SUCCESS) {
        LOG_ERROR("GetServiceBindingKey error = %u", ret);
        goto EXIT;
    }

    // NOTE(review): keyVersion is narrowed to uint8_t here; confirm versions never exceed 255.
    ret = ChannelEnableSecureChannelProtocol(context->base.channel, key, keyLen, (uint8_t)keyVersion, 1);
    if (ret != SUCCESS) {
        LOG_ERROR("ChannelEnableSecureChannelProtocol failed = %u", ret);
        goto EXIT;
    }

    ret = ChannelSetSecureChannelProtocolChecker(context->base.channel, DefaultScpEnableChecker);

EXIT:
    // Scrub only after the last channel call, so the key stays valid for as long as any callee could need it.
    (void)memset_s(key, BUFFER_DATA_LEN, 0, BUFFER_DATA_LEN);
    return ret;
}